clc;
clearvars;
load('Data');   % presumably provides the household table `Data` used below - confirm
format short g


%=========== Define simulation parameters ===========%
% cluster     : 0 selects the local configuration, any other value the cluster one
% market_size : number of households drawn for each simulated market
% max_iter    : number of iterations of the parallel estimation loop
cluster = 1;
market_size = 100;
max_iter = 100;

% Character versions of the parameters (used to build the results file name).
text_iter = num2str(max_iter);
text_size = num2str(market_size);
text_cluster = num2str(cluster);


%=========== Define variables to store results ===========%
% One row (or cell) per iteration; numeric rows stay NaN until they are filled.
StabilityIndices = cell(max_iter,1);
HHIndices = cell(max_iter,1);
DivorceCosts = NaN(max_iter,12);
shareF = NaN(max_iter,9);
shareM = NaN(max_iter,9);


%=========== Select solver, worker count and toolbox paths ===========%
% parnum is passed to parfor as the maximum number of workers.
if cluster ~= 0
    text_solver = 'mosek';
    parnum = 20;
    addpath(genpath("/opt/apps/easybuild/software/YALMIP/R20230609/"))
    addpath(genpath("~/mosek/10.1/toolbox/r2017a"))
else
    text_solver = 'gurobi';
    parnum = 2;
end


%=========== Define chores = housework + childcare hours ===========%
% Chore hours are housework plus childcare hours. The male entry is zeroed
% for single-female households and the female entry for single-male ones.
Data.chorehm = Data.hworkm + Data.chcarem;
Data.chorehm(Data.singlefemale == 1) = 0;

Data.chorehf = Data.hworkf + Data.chcaref;
Data.chorehf(Data.singlemale == 1) = 0;


%=========== Define povertyline ===========%
% Per-capita consumption halves total consumption (q + Q) for couples and
% keeps it unchanged for households with couple == 0.
Data.consumption = Data.q + Data.Q;
% find() is kept on the left-hand side so the implicit creation of the new
% table variable is exactly as before.
Data.percapitaconsumption(find(Data.couple == 1)) = Data.consumption(Data.couple == 1)/2;
Data.percapitaconsumption(find(Data.couple == 0)) = Data.consumption(Data.couple == 0);
% Poverty line: 60 percent of the median per-capita consumption.
povertyline = median(Data.percapitaconsumption)*0.6;


%=========== Define age restrictions for marriage market ===========%
% Husband-minus-wife age gap among observed couples; its 2.5th and 97.5th
% percentiles are the bounds passed to get_consideration_sets.
age_diff = Data.agem(Data.couple==1) - Data.agef(Data.couple==1);
lb = prctile(age_diff,2.5);
ub = prctile(age_diff,97.5);


%===========  Define age, education and children categories for preference types ===========%
% Category codes (0 always means "missing"):
%   agecat  : 1 = (1,35], 2 = (35,50], 3 = above 50
%   educat  : the education value itself
%   kidscat : 1 = no children, 2 = at least one child
% Individual type = education*100 + kids*10 + age category.
for i = 1:size(Data,1)

    % --- male age category (thresholds checked from the top down) ---
    if isnan(Data.agem(i))
        Data.agecatm(i) = 0;
    elseif Data.agem(i) > 50
        Data.agecatm(i) = 3;
    elseif Data.agem(i) > 35
        Data.agecatm(i) = 2;
    elseif Data.agem(i) > 1
        Data.agecatm(i) = 1;
    end

    % --- female age category ---
    if isnan(Data.agef(i))
        Data.agecatf(i) = 0;
    elseif Data.agef(i) > 50
        Data.agecatf(i) = 3;
    elseif Data.agef(i) > 35
        Data.agecatf(i) = 2;
    elseif Data.agef(i) > 1
        Data.agecatf(i) = 1;
    end

    % --- education categories: copy the value, missing becomes 0 ---
    if ~isnan(Data.edum(i))
        Data.educatm(i) = Data.edum(i);
    else
        Data.educatm(i) = 0;
    end

    if ~isnan(Data.eduf(i))
        Data.educatf(i) = Data.eduf(i);
    else
        Data.educatf(i) = 0;
    end

    % --- children category ---
    if Data.nchild(i) > 0
        Data.kidscat(i) = 2;
    elseif Data.nchild(i) == 0
        Data.kidscat(i) = 1;
    elseif isnan(Data.nchild(i))
        Data.kidscat(i) = 0;
    end

    % --- combine into one type code per spouse ---
    Data.typem(i) = Data.educatm(i)*100 + Data.kidscat(i)*10 + Data.agecatm(i);
    Data.typef(i) = Data.educatf(i)*100 + Data.kidscat(i)*10 + Data.agecatf(i);

    % Singles have no partner: zero the type of the absent spouse.
    if Data.singlemale(i) == 1
        Data.typef(i) = 0;
    elseif Data.singlefemale(i) == 1
        Data.typem(i) = 0;
    end
end


%=========== Define household types and weights ===========%
% Household type code: male type shifted by three digits plus female type.
Data.types = 1000.*Data.typem + Data.typef;
[C,~,ic] = unique(Data.types);
a_counts = accumarray(ic,1);
% Columns of value_counts: [type code, number of households, sample share].
value_counts = [C, a_counts, a_counts./sum(a_counts)];


%=========== Define couple type based on education and employment status of spouses ===========%
% Per spouse: 10*(1 + employment indicator) + education category.
% Couple code: male code * 100 + female code. Single households get code 0.
for i = 1:size(Data,1)
    if Data.couple(i) == 1
        Data.hcatm(i) = 10*(1+Data.male_emp(i)) + Data.educatm(i);
        Data.hcatf(i) = 10*(1+Data.female_emp(i)) + Data.educatf(i);
        Data.hcat(i) = Data.hcatm(i)*100 + Data.hcatf(i);
    elseif Data.singlemale(i) == 1 || Data.singlefemale(i) == 1
        Data.hcat(i) = 0;
    end
end



%%
%=========== Estimate the model with subsamples ===========%
% Every iteration resamples a market of market_size households from Data,
% computes the divorce-cost (stability) indices and the RICEBs, and redraws
% until both steps succeed. Results are written to the sliced outputs
% DivorceCosts, StabilityIndices, HHIndices, shareM and shareF.
sc = parallel.pool.Constant(RandStream('Threefry'));


parfor (iter = 1:max_iter,parnum)

    % Each iteration uses substream `iter` of the shared stream, so its draws
    % do not depend on which worker executes it.
    stream = sc.Value;
    stream.Substream = iter;

    complete = 0;
    while complete == 0

        %=========== randomly draw household types from their weighted distribution ===========%
        rand_index = randsample(stream,value_counts(:,1), market_size, true, value_counts(:,3));
        [C,~,ic] = unique(rand_index);
        a_counts = accumarray(ic,1);
        index_counts = [C, a_counts];   % [type code, number of draws of that type]

        %=========== Given the number of household types required, randomly draw households of that type from the observed households ===========%
        Data_r = [];
        for i = 1:size(index_counts,1)
            Extract = Data(Data.types == index_counts(i,1),:);
            Extract_index = randsample(stream,size(Extract,1),index_counts(i,2),true);
            Data_r = [Data_r; Extract(Extract_index,:)];
        end

        %=========== Define consideration set ===========%
        consideration_set = get_consideration_sets(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,Data_r.agem,Data_r.agef,lb,ub);

        %=========== Compute stability indices ===========%
        fprintf('Computing Divorce Costs...\n');
        [temp_DivorceCostsM,temp_DivorceCostsF,temp_DivorceCostsMF,solved] = get_indices(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
            Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
            Data_r.nonlabor,consideration_set,text_solver);

        %=========== Compute RICEBs ===========%
        % NOTE(review): solved == 0 and problem == 0 are treated as success
        % codes here - confirm against get_indices / get_ricebs.
        if solved == 0
            fprintf('Computing RICEBs...\n');
            [temp_min_male,temp_max_male,temp_min_female,temp_max_female,problem] = get_ricebs(Data_r.couple,Data_r.singlemale,Data_r.singlefemale,...
                Data_r.leisurem,Data_r.leisuref,Data_r.chorehm,Data_r.chorehf,Data_r.wagem,Data_r.wagef,Data_r.q,Data_r.Q,...
                Data_r.nonlabor,consideration_set,temp_DivorceCostsM+eps,temp_DivorceCostsF+eps,temp_DivorceCostsMF+eps,Data_r.hcat,text_solver);

            % Accept the draw only if every bound vector has at least one
            % non-NaN entry and get_ricebs reported no problem.
            if sum(~isnan(temp_min_male(:))) && sum(~isnan(temp_max_male(:))) && sum(~isnan(temp_min_female(:))) && sum(~isnan(temp_max_female(:))) && problem == 0

                %=========== Store stability indices ===========%
                % Row indices of the couples in this draw, computed once.
                couple_idx = find(Data_r.couple);
                ncouple = numel(couple_idx);
                cons_couple = Data_r.consumption(couple_idx);

                % Pairs outside the consideration set do not enter the summaries.
                temp_DivorceCostsMF(~consideration_set) = NaN;
                DC_temp = temp_DivorceCostsMF;

                % Row-wise (male) and column-wise (female) maxima and averages
                % of the pairwise indices. Dimensions are explicit so the
                % reductions never collapse the wrong way for small matrices.
                DivorceCosts_maxm = max(DC_temp,[],2);
                DivorceCosts_maxm = DivorceCosts_maxm(couple_idx);
                DivorceCosts_maxf = max(DC_temp,[],1)';
                DivorceCosts_maxf = DivorceCosts_maxf(couple_idx);
                DivorceCosts_avgf = (sum(DC_temp,1,'omitnan')')./(sum(~isnan(DC_temp),1)');
                DivorceCosts_avgf = DivorceCosts_avgf(couple_idx);
                DivorceCosts_avgm = (sum(DC_temp,2,'omitnan'))./(sum(~isnan(DC_temp),2));
                DivorceCosts_avgm = DivorceCosts_avgm(couple_idx);

                % Columns 1-6: indices relative to household consumption,
                % 7: iteration number, 8-9: consideration-set sizes by row / column.
                DivorceCosts_temp = zeros(ncouple,9);
                DivorceCosts_temp(:,1) = (temp_DivorceCostsM(couple_idx))./cons_couple;
                DivorceCosts_temp(:,2) = (temp_DivorceCostsF(couple_idx))./cons_couple;
                DivorceCosts_temp(:,3) = DivorceCosts_maxm./cons_couple;
                DivorceCosts_temp(:,4) = DivorceCosts_maxf./cons_couple;
                DivorceCosts_temp(:,5) = DivorceCosts_avgm./cons_couple;
                DivorceCosts_temp(:,6) = DivorceCosts_avgf./cons_couple;
                DivorceCosts_temp(:,7) = iter;
                % double() keeps the reduction valid whether the consideration
                % set is logical or numeric.
                mmsize = sum(double(consideration_set),2,'omitnan');
                DivorceCosts_temp(:,8) = mmsize(couple_idx);
                mmsize = sum(double(consideration_set),1,'omitnan')';
                DivorceCosts_temp(:,9) = mmsize(couple_idx);

                NBPStable = DivorceCosts_temp(:,5) <= 0.00001 & DivorceCosts_temp(:,6) <= 0.00001;
                IRMStable = DivorceCosts_temp(:,1) <= 0.00001;
                IRFStable = DivorceCosts_temp(:,2) <= 0.00001;
                % Average over couples (dimension 1). Without the explicit
                % dimension a draw with exactly one couple would be reduced to
                % a scalar and the 12-column assignment below would fail.
                DivorceCosts(iter,:) = [mean(DivorceCosts_temp,1), mean(NBPStable,1), mean(IRMStable,1), mean(IRFStable,1)];

                % Full index matrix for this draw: pairwise block, male
                % individual indices appended as last column, female individual
                % indices as last row (corner padded with 0).
                temp_DivorceCostsM(~Data_r.couple) = NaN;
                temp_DivorceCostsF(~Data_r.couple) = NaN;
                DC_temp = [DC_temp, temp_DivorceCostsM];
                DC_temp = [DC_temp; [temp_DivorceCostsF',0]];
                StabilityIndices{iter} = DC_temp;
                HHIndices{iter} = Data_r.familyid;

                %=========== Store RICEBs ===========%
                % Row layout: [min_1, max_1, min_2, max_2, ..., iter].
                male_temp = zeros(1,2*size(temp_min_male,1)+1);
                for i = 1:size(temp_min_male,1)
                    male_temp(1,2*i-1) = temp_min_male(i);
                    male_temp(1,2*i) = temp_max_male(i);
                end
                male_temp(1,2*size(temp_min_male,1)+1) = iter;
                shareM(iter,:) = male_temp;

                female_temp = zeros(1,2*size(temp_min_female,1)+1);
                for i = 1:size(temp_min_female,1)
                    female_temp(1,2*i-1) = temp_min_female(i);
                    female_temp(1,2*i) = temp_max_female(i);
                end
                female_temp(1,2*size(temp_min_female,1)+1) = iter;
                shareF(iter,:) = female_temp;

                complete = 1;
                fprintf('*********** Iteration %d finished *************\n', iter);
            end

        end
    end
end

% Save the whole workspace under a name that encodes the run parameters.
filename = ['DataPSID_results_ricebs_4cat_size',text_size,'_iter',text_iter,'_Cluster',text_cluster];
save(filename)

%============== TABLE 6 =======================%
% Column means over iterations, ignoring NaN entries. Columns come in
% (lower bound, upper bound) pairs, one pair per category.
avg_shareM = mean(shareM, 1, "omitnan");
avg_shareF = mean(shareF, 1, "omitnan");

fprintf('Bounds for male employed = no and education = low = [%.4f, %.4f] \n', avg_shareM(1), avg_shareM(2));
fprintf('Bounds for male employed = no and education = high = [%.4f, %.4f] \n', avg_shareM(3), avg_shareM(4));
fprintf('Bounds for male employed = yes and education = low = [%.4f, %.4f] \n', avg_shareM(5), avg_shareM(6));
fprintf('Bounds for male employed = yes and education = high = [%.4f, %.4f] \n', avg_shareM(7), avg_shareM(8));

fprintf('Bounds for female employed = no and education = low = [%.4f, %.4f] \n', avg_shareF(1), avg_shareF(2));
fprintf('Bounds for female employed = no and education = high = [%.4f, %.4f] \n', avg_shareF(3), avg_shareF(4));
fprintf('Bounds for female employed = yes and education = low = [%.4f, %.4f] \n', avg_shareF(5), avg_shareF(6));
fprintf('Bounds for female employed = yes and education = high = [%.4f, %.4f] \n', avg_shareF(7), avg_shareF(8));


%===================== TABLE 15 =========================%
% Mean over iterations of the first six DivorceCosts columns, in percent.
avg_DivorceCosts = mean(DivorceCosts, 1, "omitnan");

fprintf("IR male: %.2f \n", avg_DivorceCosts(1)*100);
fprintf("IR female: %.2f \n", avg_DivorceCosts(2)*100);
fprintf("NBP max male: %.2f \n", avg_DivorceCosts(3)*100);
fprintf("NBP max female: %.2f \n", avg_DivorceCosts(4)*100);
fprintf("NBP avg male: %.2f \n", avg_DivorceCosts(5)*100);
fprintf("NBP avg female: %.2f \n", avg_DivorceCosts(6)*100);


%=========== Summarize Percentage Stability Indices Below Cut-off ===========%
% StabilityIndices{i} is an (n+1)-by-(n+1) matrix: the leading n-by-n block
% holds the pairwise indices, the last column the male individual indices and
% the last row the female individual indices. n is read from the stored
% matrix instead of being hard-coded, so the summary stays correct when
% market_size is changed.
NBPIndices = cell(max_iter,1);
IRMIndices = cell(max_iter,1);
IRFIndices = cell(max_iter,1);

for i = 1:max_iter
    n_hh = size(StabilityIndices{i},1) - 1;
    NBPIndices{i} = StabilityIndices{i}(1:n_hh, 1:n_hh);
    IRMIndices{i} = StabilityIndices{i}(1:n_hh, n_hh+1);
    IRFIndices{i} = (StabilityIndices{i}(n_hh+1,1:n_hh))';
end


cutoff = [1e-6, 0.0001, 0.0005, 0.01, 0.05, 0.1, 0.5, 1, 5, 10, 15];
% Columns of IndicesBelow: share of non-NaN indices at or below the cutoff for
% [male individual, female individual, pairwise, per-household maximum,
%  per-household average], followed by the cutoff itself.
IndicesBelow = NaN(length(cutoff),6);

for j = 1:length(cutoff)
    result = NaN(max_iter,5);

    for i = 1:max_iter

        result1 = sum(IRMIndices{i} <= cutoff(j) & ~isnan(IRMIndices{i}))/sum(~isnan(IRMIndices{i}));
        result2 = sum(IRFIndices{i} <= cutoff(j) & ~isnan(IRFIndices{i}))/sum(~isnan(IRFIndices{i}));
        result3 = sum(NBPIndices{i}(:) <= cutoff(j) & ~isnan(NBPIndices{i}(:)))/sum(~isnan(NBPIndices{i}(:)));

        % Largest pairwise index per row and per column. A position is kept
        % only when both its row and its column contain a non-NaN entry.
        NBPm_max = max(NBPIndices{i},[],2);
        NBPf_max = max(NBPIndices{i},[],1)';
        findcouple = ~isnan(NBPm_max) & ~isnan(NBPf_max);
        NBP_max = max(NBPm_max(findcouple), NBPf_max(findcouple));

        % Average pairwise index over the row and column of each kept position.
        NBPm_sum = sum(NBPIndices{i},2,"omitnan");
        NBPm_num = sum(~isnan(NBPIndices{i}),2);

        NBPf_sum = sum(NBPIndices{i},1,"omitnan")';
        NBPf_num = sum(~isnan(NBPIndices{i}),1)';

        NBP_sum = NBPm_sum(findcouple) + NBPf_sum(findcouple);
        NBP_num = NBPm_num(findcouple) + NBPf_num(findcouple);
        NBP_avg = NBP_sum./NBP_num;

        result4 = sum(NBP_max <= cutoff(j))/length(NBP_max);
        result5 = sum(NBP_avg <= cutoff(j))/length(NBP_avg);

        result(i,:) = [result1, result2, result3, result4, result5];
    end
    % Average over iterations (dimension 1), which also holds for max_iter == 1.
    IndicesBelow(j,:) = [mean(result,1), cutoff(j)];
end


%================ TABLE 16 ====================%
display(IndicesBelow)




